{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MpkYHwCqk7W-"
      },
      "source": [
        "![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xBSdkbmGN2K-"
      },
      "source": [
        "### Copyright notice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_UbO9uhtBSX5"
      },
      "source": [
        "> <p><small><small>Copyright 2025 DeepMind Technologies Limited.</small></small></p>\n",
        "> <p><small><small>Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href=\"http://www.apache.org/licenses/LICENSE-2.0\">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>\n",
        "> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dNIJkb_FM2Ux"
      },
      "source": [
        "# Welcome to MuJoCo Playground! <a href=\"https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/welcome_to_the_playground.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/></a>\n",
        "\n",
        "Welcome to MuJoCo Playground, we're excited to meet you! MuJoCo Playground contains a comprehensive suite of environments for reinforcement learning and robotics research. In this notebook, we'll give a tour of [DM Control Suite](https://github.com/google-deepmind/dm_control/tree/main/dm_control/suite) environments that were ported to run on GPU via [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html).\n",
        "\n",
        "**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime > Change runtime type\".\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xqo7pyX-n72M",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Install pre-requisites\n",
        "# One pip invocation lets the resolver pick mutually compatible versions of all\n",
        "# three packages; %pip installs into the environment of the running kernel.\n",
        "%pip install mujoco mujoco_mjx brax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "IbZxYDxzoz5R"
      },
      "outputs": [],
      "source": [
        "# @title Check if MuJoCo installation was successful\n",
        "\n",
        "import os\n",
        "import subprocess\n",
        "\n",
        "# nvidia-smi exits non-zero when it cannot reach a GPU. If the binary is not on\n",
        "# PATH at all, subprocess.run raises FileNotFoundError; treat that the same way\n",
        "# so the user still gets the explanatory message below.\n",
        "try:\n",
        "  gpu_available = subprocess.run('nvidia-smi').returncode == 0\n",
        "except FileNotFoundError:\n",
        "  gpu_available = False\n",
        "if not gpu_available:\n",
        "  raise RuntimeError(\n",
        "      'Cannot communicate with GPU. '\n",
        "      'Make sure you are using a GPU Colab runtime. '\n",
        "      'Go to the Runtime menu and select Change runtime type.'\n",
        "  )\n",
        "\n",
        "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
        "# This is usually installed as part of an Nvidia driver package, but the Colab\n",
        "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
        "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
        "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
        "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
        "  # Make sure the vendor directory exists before writing the config into it.\n",
        "  os.makedirs(os.path.dirname(NVIDIA_ICD_CONFIG_PATH), exist_ok=True)\n",
        "  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
        "    f.write(\"\"\"{\n",
        "    \"file_format_version\" : \"1.0.0\",\n",
        "    \"ICD\" : {\n",
        "        \"library_path\" : \"libEGL_nvidia.so.0\"\n",
        "    }\n",
        "}\n",
        "\"\"\")\n",
        "\n",
        "# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n",
        "print('Setting environment variable to use GPU rendering:')\n",
        "%env MUJOCO_GL=egl\n",
        "\n",
        "try:\n",
        "  print('Checking that the installation succeeded:')\n",
        "  import mujoco\n",
        "\n",
        "  mujoco.MjModel.from_xml_string('<mujoco/>')\n",
        "except Exception as e:\n",
        "  # Raise the explanatory error and chain the original failure as its cause, so\n",
        "  # the traceback shows both.\n",
        "  raise RuntimeError(\n",
        "      'Something went wrong during installation. Check the shell output above '\n",
        "      'for more information.\\n'\n",
        "      'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
        "      'by going to the Runtime menu and selecting \"Change runtime type\".'\n",
        "  ) from e\n",
        "\n",
        "print('Installation successful.')\n",
        "\n",
        "# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\n",
        "# The flag is appended only when absent so that re-running this cell does not\n",
        "# keep growing XLA_FLAGS.\n",
        "TRITON_GEMM_FLAG = '--xla_gpu_triton_gemm_any=True'\n",
        "xla_flags = os.environ.get('XLA_FLAGS', '')\n",
        "if TRITON_GEMM_FLAG not in xla_flags:\n",
        "  xla_flags += ' ' + TRITON_GEMM_FLAG\n",
        "os.environ['XLA_FLAGS'] = xla_flags"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "T5f4w3Kq2X14"
      },
      "outputs": [],
      "source": [
        "# @title Import packages for plotting and creating graphics\n",
        "import itertools\n",
        "import time\n",
        "from typing import Callable, List, NamedTuple, Optional, Union\n",
        "import numpy as np\n",
        "\n",
        "# Graphics and plotting.\n",
        "print(\"Installing mediapy:\")\n",
        "!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
        "!pip install -q mediapy\n",
        "import mediapy as media\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# More legible printing from numpy.\n",
        "np.set_printoptions(precision=3, suppress=True, linewidth=100)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ObF1UXrkb0Nd"
      },
      "outputs": [],
      "source": [
        "# @title Import MuJoCo, MJX, and Brax\n",
        "from datetime import datetime\n",
        "import functools\n",
        "import os\n",
        "from typing import Any, Dict, Sequence, Tuple, Union\n",
        "from brax import base\n",
        "from brax import envs\n",
        "from brax import math\n",
        "from brax.base import Base, Motion, Transform\n",
        "from brax.base import State as PipelineState\n",
        "from brax.envs.base import Env, PipelineEnv, State\n",
        "from brax.io import html, mjcf, model\n",
        "from brax.mjx.base import State as MjxState\n",
        "from brax.training.agents.ppo import networks as ppo_networks\n",
        "from brax.training.agents.ppo import train as ppo\n",
        "from brax.training.agents.sac import networks as sac_networks\n",
        "from brax.training.agents.sac import train as sac\n",
        "from etils import epath\n",
        "from flax import struct\n",
        "from flax.training import orbax_utils\n",
        "from IPython.display import HTML, clear_output\n",
        "import jax\n",
        "from jax import numpy as jp\n",
        "from matplotlib import pyplot as plt\n",
        "import mediapy as media\n",
        "from ml_collections import config_dict\n",
        "import mujoco\n",
        "from mujoco import mjx\n",
        "import numpy as np\n",
        "from orbax import checkpoint as ocp"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Install MuJoCo Playground\n",
        "# %pip installs into the environment of the running kernel.\n",
        "%pip install playground"
      ],
      "metadata": {
        "cellView": "form",
        "id": "UoTLSx4cFRdy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Intro\n",
        "\n",
        "MuJoCo Playground contains environments for DM Control Suite, Robotic Locomotion, and Robotic Manipulation. You can load any of these environments via the environment registry:"
      ],
      "metadata": {
        "id": "_R01tjWfI-i6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Imported here rather than with the earlier imports; presumably because the\n",
        "# package only becomes importable after the install cell above has run.\n",
        "from mujoco_playground import registry\n",
        "# Load the environment registered under the given name.\n",
        "env = registry.load('CartpoleBalance')\n",
        "# A bare name on the last line displays the object as the cell output.\n",
        "env"
      ],
      "metadata": {
        "id": "kPJeoQeEJBSA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Each environment is also associated with an environment config, which can be overridden if so desired. Let's also load the config:"
      ],
      "metadata": {
        "id": "3e3MYIKnJnQn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Default config for the same environment name; shown as the cell output.\n",
        "env_cfg = registry.get_default_config('CartpoleBalance')\n",
        "env_cfg"
      ],
      "metadata": {
        "id": "bd8RSkV8JuLe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Notice that the environment config contains `sim_dt` and `ctrl_dt`. Each simulation step runs with a timestep of `sim_dt`. `ctrl_dt` determines how much time passes by for each `env.step`. Thus every `env.step` will step the simulation `ctrl_dt / sim_dt` times.\n",
        "\n",
        "Another parameter worth noting is `vision_config`, which we discuss in more detail in the vision-based notebooks! For now, we'll stick to privileged observations."
      ],
      "metadata": {
        "id": "pw_y1cVYJzDE"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Rollout\n",
        "\n",
        "Let's now do a simple rollout. This closely follows the [MJX tutorial notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb). If you're familiar with MJX, this should look familiar; otherwise, feel free to check out the MJX tutorial first!"
      ],
      "metadata": {
        "id": "_6D8A3zCKn_K"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PDGmZ6zGLDws"
      },
      "outputs": [],
      "source": [
        "# jax.jit compiles reset and step; the jitted versions are used for the rollout.\n",
        "jit_reset = jax.jit(env.reset)\n",
        "jit_step = jax.jit(env.step)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6g7mZ1DyLDws"
      },
      "outputs": [],
      "source": [
        "state = jit_reset(jax.random.PRNGKey(0))\n",
        "rollout = [state]\n",
        "\n",
        "# Open-loop control: every actuator follows a sinusoid of frequency `f`, with the\n",
        "# phases spread evenly over one period across the action dimensions.\n",
        "f = 0.5\n",
        "# The phase offsets are state-independent, so build them once as a vector\n",
        "# instead of recomputing one scalar per actuator on every step.\n",
        "phase_offsets = np.arange(env.action_size) * 2 * np.pi / env.action_size\n",
        "for _ in range(env_cfg.episode_length):\n",
        "  action = jp.sin(state.data.time * 2 * jp.pi * f + phase_offsets)\n",
        "  state = jit_step(state, action)\n",
        "  rollout.append(state)\n",
        "\n",
        "# Render every recorded state and play the video at 1 / env.dt frames per second.\n",
        "frames = env.render(rollout)\n",
        "media.show_video(frames, fps=1.0 / env.dt)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "If you're running the notebook with a GPU instance (which you should), notice that the environment runs and lives on the device! That of course means we can do some large-batch RL on the GPU."
      ],
      "metadata": {
        "id": "gP8sNROjLc_p"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Device holding the observation array of the last state from the rollout above.\n",
        "state.obs.device"
      ],
      "metadata": {
        "id": "WcdKUCt6LDYg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# RL"
      ],
      "metadata": {
        "id": "Thm7nZueM4cz"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "We'll use [brax](https://github.com/google/brax) to train some RL policies, but we also show examples with [RSL-RL](https://github.com/google-deepmind/mujoco_playground/tree/main/learning/train_rsl_rl.py) in the main repo via `python train_rsl_rl.py`. We encourage `pytorch` users to take a look!\n",
        "\n",
        "For now we'll go ahead with a brax PPO example.\n",
        "\n"
      ],
      "metadata": {
        "id": "bkeGmfKAL2WA"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "MuJoCo Playground comes with a set of hyper-parameters for all available environments via `mujoco_playground.config`. A selection of environments contain configs for RSL-RL and for vision-based brax PPO. Let's load the brax PPO config for CartpoleBalance!"
      ],
      "metadata": {
        "id": "F9IfxRqdmQAI"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from mujoco_playground.config import dm_control_suite_params\n",
        "# Fetch the brax PPO config for CartpoleBalance and show it as the cell output.\n",
        "ppo_params = dm_control_suite_params.brax_ppo_config('CartpoleBalance')\n",
        "ppo_params"
      ],
      "metadata": {
        "id": "B9T_UVZYLDdM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "And let's of course train a policy!"
      ],
      "metadata": {
        "id": "FOaeCdOtNtRi"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vBEEQyY6M5OC"
      },
      "source": [
        "### PPO"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XKFzyP7wM5OD"
      },
      "outputs": [],
      "source": [
        "# Evaluation history that `progress` appends to.\n",
        "x_data, y_data, y_dataerr = [], [], []\n",
        "# Timestamps: the first entry is taken when this cell runs, then one is added\n",
        "# per `progress` call.\n",
        "times = [datetime.now()]\n",
        "\n",
        "\n",
        "def progress(num_steps, metrics):\n",
        "  \"\"\"Stores the newest evaluation result and redraws the reward curve.\n",
        "\n",
        "  Args:\n",
        "    num_steps: x-axis value of the new data point.\n",
        "    metrics: mapping with the keys \"eval/episode_reward\" and\n",
        "      \"eval/episode_reward_std\".\n",
        "  \"\"\"\n",
        "  clear_output(wait=True)\n",
        "\n",
        "  times.append(datetime.now())\n",
        "  x_data.append(num_steps)\n",
        "  y_data.append(metrics[\"eval/episode_reward\"])\n",
        "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
        "\n",
        "  # Leave 25% headroom to the right of the configured number of timesteps.\n",
        "  x_max = ppo_params[\"num_timesteps\"] * 1.25\n",
        "  plt.xlim([0, x_max])\n",
        "  plt.ylim([0, 1100])\n",
        "  plt.xlabel(\"# environment steps\")\n",
        "  plt.ylabel(\"reward per episode\")\n",
        "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
        "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
        "\n",
        "  display(plt.gcf())\n",
        "\n",
        "# Split the config: the `network_factory` entry (if any) parameterizes the\n",
        "# network builder, every other entry is forwarded to `ppo.train` as a keyword.\n",
        "ppo_training_params = dict(ppo_params)\n",
        "if \"network_factory\" in ppo_params:\n",
        "  network_factory = functools.partial(\n",
        "      ppo_networks.make_ppo_networks,\n",
        "      **ppo_training_params.pop(\"network_factory\"),\n",
        "  )\n",
        "else:\n",
        "  network_factory = ppo_networks.make_ppo_networks\n",
        "\n",
        "train_fn = functools.partial(\n",
        "    ppo.train,\n",
        "    network_factory=network_factory,\n",
        "    progress_fn=progress,\n",
        "    **ppo_training_params,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FGrlulWbM5OD"
      },
      "outputs": [],
      "source": [
        "from mujoco_playground import wrapper\n",
        "\n",
        "make_inference_fn, params, metrics = train_fn(\n",
        "    environment=env,\n",
        "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        ")\n",
        "print(f\"time to jit: {times[1] - times[0]}\")\n",
        "print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "If you're familiar with brax, you'll notice that we provided a custom wrapper to the `train_fn`. That's because MuJoCo Playground environments are not compatible with the vanilla brax wrappers. For RSL-RL, we ship wrappers in [`wrapper_torch.py`](https://github.com/google-deepmind/mujoco_playground/tree/main/mujoco_playground/_src/wrapper_torch.py)."
      ],
      "metadata": {
        "id": "S-okEidWPIIg"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Visualize Rollout"
      ],
      "metadata": {
        "id": "mHVmccs-oMSo"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sG5a2FFXoUKw"
      },
      "outputs": [],
      "source": [
        "# Jit the environment functions and the policy built from the trained `params`.\n",
        "# deterministic=True is forwarded to `make_inference_fn` (presumably selecting\n",
        "# non-sampled actions - confirm against brax).\n",
        "jit_reset = jax.jit(env.reset)\n",
        "jit_step = jax.jit(env.step)\n",
        "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C_1CY9xDoUKw"
      },
      "outputs": [],
      "source": [
        "rng = jax.random.PRNGKey(42)\n",
        "rollout = []\n",
        "n_episodes = 1\n",
        "\n",
        "for _ in range(n_episodes):\n",
        "  # Give the reset its own subkey: a JAX PRNG key should be consumed only once,\n",
        "  # and `rng` is split again below to derive the per-step action keys.\n",
        "  reset_rng, rng = jax.random.split(rng)\n",
        "  state = jit_reset(reset_rng)\n",
        "  rollout.append(state)\n",
        "  for i in range(env_cfg.episode_length):\n",
        "    act_rng, rng = jax.random.split(rng)\n",
        "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "    state = jit_step(state, ctrl)\n",
        "    rollout.append(state)\n",
        "\n",
        "# Rendering only every `render_every`-th state divides the playback fps by the\n",
        "# same factor, so the video duration is unchanged.\n",
        "render_every = 1\n",
        "frames = env.render(rollout[::render_every])\n",
        "rewards = [s.reward for s in rollout]\n",
        "media.show_video(frames, fps=1.0 / env.dt / render_every)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a3NXzZCjTskz"
      },
      "source": [
        "# DM Control Suite - Take a spin!\n",
        "\n",
        "Feel free to now take a spin on any of the DM Control Suite environments! The world is your oyster."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "CilNmELAdeNk"
      },
      "outputs": [],
      "source": [
        "env_name = \"FishSwim\"  # @param [\"AcrobotSwingup\", \"AcrobotSwingupSparse\", \"BallInCup\", \"CartpoleBalance\", \"CartpoleBalanceSparse\", \"CartpoleSwingup\", \"CartpoleSwingupSparse\", \"CheetahRun\", \"FingerSpin\", \"FingerTurnEasy\", \"FingerTurnHard\", \"FishSwim\", \"HopperHop\", \"HopperStand\", \"HumanoidStand\", \"HumanoidWalk\", \"HumanoidRun\", \"PendulumSwingup\", \"PointMass\", \"ReacherEasy\", \"ReacherHard\", \"SwimmerSwimmer6\", \"WalkerRun\", \"WalkerStand\", \"WalkerWalk\"]\n",
        "# One entry per environment offered in the dropdown above; the values are\n",
        "# presumably names of cameras defined in each environment's model.\n",
        "CAMERAS = {\n",
        "    \"AcrobotSwingup\": \"fixed\",\n",
        "    \"AcrobotSwingupSparse\": \"fixed\",\n",
        "    \"BallInCup\": \"cam0\",\n",
        "    \"CartpoleBalance\": \"fixed\",\n",
        "    \"CartpoleBalanceSparse\": \"fixed\",\n",
        "    \"CartpoleSwingup\": \"fixed\",\n",
        "    \"CartpoleSwingupSparse\": \"fixed\",\n",
        "    \"CheetahRun\": \"side\",\n",
        "    \"FingerSpin\": \"cam0\",\n",
        "    \"FingerTurnEasy\": \"cam0\",\n",
        "    \"FingerTurnHard\": \"cam0\",\n",
        "    \"FishSwim\": \"fixed_top\",\n",
        "    \"HopperHop\": \"cam0\",\n",
        "    \"HopperStand\": \"cam0\",\n",
        "    \"HumanoidStand\": \"side\",\n",
        "    \"HumanoidWalk\": \"side\",\n",
        "    \"HumanoidRun\": \"side\",\n",
        "    \"PendulumSwingup\": \"fixed\",\n",
        "    \"PointMass\": \"cam0\",\n",
        "    \"ReacherEasy\": \"fixed\",\n",
        "    \"ReacherHard\": \"fixed\",\n",
        "    \"SwimmerSwimmer6\": \"tracking1\",\n",
        "    \"WalkerRun\": \"side\",\n",
        "    \"WalkerWalk\": \"side\",\n",
        "    \"WalkerStand\": \"side\",\n",
        "}\n",
        "# NOTE(review): `camera_name` is not referenced by any later cell in this\n",
        "# notebook; presumably it was meant to be passed to `env.render` - confirm.\n",
        "camera_name = CAMERAS[env_name]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6m1K6y4IOUj_"
      },
      "outputs": [],
      "source": [
        "# Rebind `env_cfg` and `env` to the environment selected in the form above.\n",
        "env_cfg = registry.get_default_config(env_name)\n",
        "env = registry.load(env_name, config=env_cfg)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wLHkrJhAOUj_"
      },
      "source": [
        "## Visualize the environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9Ry7wWduOUkA"
      },
      "outputs": [],
      "source": [
        "# `env` was rebound in the previous cell, so the jitted reset/step are rebuilt.\n",
        "jit_reset = jax.jit(env.reset)\n",
        "jit_step = jax.jit(env.step)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WR5DX0SWOUkA"
      },
      "outputs": [],
      "source": [
        "state = jit_reset(jax.random.PRNGKey(0))\n",
        "rollout = [state]\n",
        "\n",
        "# Open-loop control: every actuator follows a sinusoid of frequency `f`, with the\n",
        "# phases spread evenly over one period across the action dimensions.\n",
        "f = 0.5\n",
        "# The phase offsets are state-independent, so build them once as a vector\n",
        "# instead of recomputing one scalar per actuator on every step.\n",
        "phase_offsets = np.arange(env.action_size) * 2 * np.pi / env.action_size\n",
        "for _ in range(env_cfg.episode_length):\n",
        "  action = jp.sin(state.data.time * 2 * jp.pi * f + phase_offsets)\n",
        "  state = jit_step(state, action)\n",
        "  rollout.append(state)\n",
        "\n",
        "# Render every recorded state and play the video at 1 / env.dt frames per second.\n",
        "frames = env.render(rollout)\n",
        "media.show_video(frames, fps=1.0 / env.dt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QuI8_ioCOUkB"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "afAaOGzSoNMd"
      },
      "outputs": [],
      "source": [
        "from mujoco_playground.config import dm_control_suite_params\n",
        "\n",
        "# Load both the PPO and the SAC configs for the selected environment; the PPO\n",
        "# and SAC sections below each read their own one.\n",
        "ppo_params = dm_control_suite_params.brax_ppo_config(env_name)\n",
        "sac_params = dm_control_suite_params.brax_sac_config(env_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t5MNLlatbS27"
      },
      "source": [
        "### PPO"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N5KgRSAZbSEX"
      },
      "outputs": [],
      "source": [
        "# Evaluation history that `progress` appends to.\n",
        "x_data, y_data, y_dataerr = [], [], []\n",
        "# Timestamps: the first entry is taken when this cell runs, then one is added\n",
        "# per `progress` call.\n",
        "times = [datetime.now()]\n",
        "\n",
        "\n",
        "def progress(num_steps, metrics):\n",
        "  \"\"\"Stores the newest evaluation result and redraws the reward curve.\n",
        "\n",
        "  Args:\n",
        "    num_steps: x-axis value of the new data point.\n",
        "    metrics: mapping with the keys \"eval/episode_reward\" and\n",
        "      \"eval/episode_reward_std\".\n",
        "  \"\"\"\n",
        "  clear_output(wait=True)\n",
        "\n",
        "  times.append(datetime.now())\n",
        "  x_data.append(num_steps)\n",
        "  y_data.append(metrics[\"eval/episode_reward\"])\n",
        "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
        "\n",
        "  # Leave 25% headroom to the right of the configured number of timesteps.\n",
        "  x_max = ppo_params[\"num_timesteps\"] * 1.25\n",
        "  plt.xlim([0, x_max])\n",
        "  plt.ylim([0, 1100])\n",
        "  plt.xlabel(\"# environment steps\")\n",
        "  plt.ylabel(\"reward per episode\")\n",
        "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
        "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
        "\n",
        "  display(plt.gcf())\n",
        "\n",
        "# Split the config: the `network_factory` entry (if any) parameterizes the\n",
        "# network builder, every other entry is forwarded to `ppo.train` as a keyword.\n",
        "ppo_training_params = dict(ppo_params)\n",
        "if \"network_factory\" in ppo_params:\n",
        "  network_factory = functools.partial(\n",
        "      ppo_networks.make_ppo_networks,\n",
        "      **ppo_training_params.pop(\"network_factory\"),\n",
        "  )\n",
        "else:\n",
        "  network_factory = ppo_networks.make_ppo_networks\n",
        "\n",
        "train_fn = functools.partial(\n",
        "    ppo.train,\n",
        "    network_factory=network_factory,\n",
        "    progress_fn=progress,\n",
        "    **ppo_training_params,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YrD6T1VWbSJD"
      },
      "outputs": [],
      "source": [
        "# `wrapper` was imported from mujoco_playground in the CartpoleBalance training\n",
        "# cell earlier in the notebook.\n",
        "make_inference_fn, params, metrics = train_fn(\n",
        "    environment=env,\n",
        "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        ")\n",
        "# times[0] was recorded when the progress cell ran, so the first interval spans\n",
        "# everything from that moment up to the first `progress` call.\n",
        "print(f\"time to jit: {times[1] - times[0]}\")\n",
        "print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "80esJi__b6J2"
      },
      "outputs": [],
      "source": [
        "# Jit the environment functions and the policy built from the `params` returned\n",
        "# by the PPO training cell above.\n",
        "jit_reset = jax.jit(env.reset)\n",
        "jit_step = jax.jit(env.step)\n",
        "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HsbR2AIRb6J2"
      },
      "outputs": [],
      "source": [
        "rng = jax.random.PRNGKey(42)\n",
        "rollout = []\n",
        "n_episodes = 1\n",
        "\n",
        "for _ in range(n_episodes):\n",
        "  # Give the reset its own subkey: a JAX PRNG key should be consumed only once,\n",
        "  # and `rng` is split again below to derive the per-step action keys.\n",
        "  reset_rng, rng = jax.random.split(rng)\n",
        "  state = jit_reset(reset_rng)\n",
        "  rollout.append(state)\n",
        "  for i in range(env_cfg.episode_length):\n",
        "    act_rng, rng = jax.random.split(rng)\n",
        "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "    state = jit_step(state, ctrl)\n",
        "    rollout.append(state)\n",
        "\n",
        "# Rendering only every `render_every`-th state divides the playback fps by the\n",
        "# same factor, so the video duration is unchanged.\n",
        "render_every = 1\n",
        "frames = env.render(rollout[::render_every])\n",
        "rewards = [s.reward for s in rollout]\n",
        "media.show_video(frames, fps=1.0 / env.dt / render_every)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HpRJnuuXb7Ax"
      },
      "source": [
        "### SAC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "besM1HxqOUkB"
      },
      "outputs": [],
      "source": [
        "# Evaluation history that `progress` appends to.\n",
        "x_data, y_data, y_dataerr = [], [], []\n",
        "# Timestamps: the first entry is taken when this cell runs, then one is added\n",
        "# per `progress` call.\n",
        "times = [datetime.now()]\n",
        "\n",
        "\n",
        "def progress(num_steps, metrics):\n",
        "  \"\"\"Stores the newest evaluation result and redraws the reward curve.\n",
        "\n",
        "  Args:\n",
        "    num_steps: x-axis value of the new data point.\n",
        "    metrics: mapping with the keys \"eval/episode_reward\" and\n",
        "      \"eval/episode_reward_std\".\n",
        "  \"\"\"\n",
        "  clear_output(wait=True)\n",
        "\n",
        "  times.append(datetime.now())\n",
        "  x_data.append(num_steps)\n",
        "  y_data.append(metrics[\"eval/episode_reward\"])\n",
        "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
        "\n",
        "  # Leave 25% headroom to the right of the configured number of timesteps.\n",
        "  x_max = sac_params[\"num_timesteps\"] * 1.25\n",
        "  plt.xlim([0, x_max])\n",
        "  plt.ylim([0, 1100])\n",
        "  plt.xlabel(\"# environment steps\")\n",
        "  plt.ylabel(\"reward per episode\")\n",
        "  plt.title(f\"y={y_data[-1]:.3f}\")\n",
        "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
        "\n",
        "  display(plt.gcf())\n",
        "\n",
        "# Split the config: the `network_factory` entry (if any) parameterizes the\n",
        "# network builder, every other entry is forwarded to `sac.train` as a keyword.\n",
        "sac_training_params = dict(sac_params)\n",
        "if \"network_factory\" in sac_params:\n",
        "  network_factory = functools.partial(\n",
        "      sac_networks.make_sac_networks,\n",
        "      **sac_training_params.pop(\"network_factory\"),\n",
        "  )\n",
        "else:\n",
        "  network_factory = sac_networks.make_sac_networks\n",
        "\n",
        "train_fn = functools.partial(\n",
        "    sac.train,\n",
        "    network_factory=network_factory,\n",
        "    progress_fn=progress,\n",
        "    **sac_training_params,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XkKerfRjOUkB"
      },
      "outputs": [],
      "source": [
        "# `train_fn` now wraps `sac.train` (rebound in the previous cell); `wrapper` was\n",
        "# imported from mujoco_playground in an earlier training cell.\n",
        "make_inference_fn, params, metrics = train_fn(\n",
        "    environment=env,\n",
        "    wrap_env_fn=wrapper.wrap_for_brax_training,\n",
        ")\n",
        "# times[0] was recorded when the progress cell ran, so the first interval spans\n",
        "# everything from that moment up to the first `progress` call.\n",
        "print(f\"time to jit: {times[1] - times[0]}\")\n",
        "print(f\"time to train: {times[-1] - times[1]}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Yw_vewrKOUkB"
      },
      "outputs": [],
      "source": [
        "# Jit the environment functions and the policy built from the `params` returned\n",
        "# by the SAC training cell above.\n",
        "jit_reset = jax.jit(env.reset)\n",
        "jit_step = jax.jit(env.step)\n",
        "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VJFfqcwoOUkB"
      },
      "outputs": [],
      "source": [
        "rng = jax.random.PRNGKey(0)\n",
        "rollout = []\n",
        "n_episodes = 1\n",
        "\n",
        "for _ in range(n_episodes):\n",
        "  # Give the reset its own subkey: a JAX PRNG key should be consumed only once,\n",
        "  # and `rng` is split again below to derive the per-step action keys.\n",
        "  reset_rng, rng = jax.random.split(rng)\n",
        "  state = jit_reset(reset_rng)\n",
        "  rollout.append(state)\n",
        "  for i in range(env_cfg.episode_length):\n",
        "    act_rng, rng = jax.random.split(rng)\n",
        "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
        "    state = jit_step(state, ctrl)\n",
        "    rollout.append(state)\n",
        "\n",
        "# Rendering only every `render_every`-th state divides the playback fps by the\n",
        "# same factor, so the video duration is unchanged.\n",
        "render_every = 1\n",
        "frames = env.render(rollout[::render_every])\n",
        "rewards = [s.reward for s in rollout]\n",
        "media.show_video(frames, fps=1.0 / env.dt / render_every)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CBtrAqns35sI"
      },
      "source": [
        "🙌 See you soon!"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "toc_visible": true,
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
